package org.testfun.jee.runner.inject;

import org.junit.Assert;
import org.testfun.jee.runner.SingletonEntityManager;

import javax.ejb.ApplicationException;
import javax.persistence.EntityTransaction;
import java.lang.reflect.InvocationHandler;
import java.lang.reflect.InvocationTargetException;
import java.lang.reflect.Method;
import java.lang.reflect.Proxy;

/**
 * Helpers that run EJB method calls inside the {@link EntityTransaction} of the entity manager
 * returned by {@link SingletonEntityManager#getInstance()}.
 *
 * <p>The call that finds no active transaction starts one and is responsible for ending it; calls
 * made while a transaction is already active join it. A failure deeper in the stack only flags the
 * transaction as rollback-only, and the starting call rolls back or commits when it finishes.
 */
public class TransactionUtils {

    /**
     * Wraps an EJB implementation in a dynamic proxy whose every method call runs inside a
     * transaction (see {@link #beginTransaction()} / {@link #endTransaction(boolean)}).
     *
     * <p>The proxy implements only the interfaces returned by {@code impl.getClass().getInterfaces()},
     * i.e. those declared directly on the implementation class, not those of its superclasses.
     *
     * @param impl the EJB implementation to delegate to; must not be {@code null}
     * @return a proxy implementing the implementation class's direct interfaces
     * @throws AssertionError if {@code impl} is {@code null} (raised by JUnit's {@code Assert.assertNotNull})
     */
    public static Object wrapEjbWithTransaction(Object impl) {
        Assert.assertNotNull("EJB Implementation is null", impl);

        Class<?>[] interfaces = impl.getClass().getInterfaces();
        return Proxy.newProxyInstance(TransactionUtils.class.getClassLoader(), interfaces, new TransactionalMethodWrapper(impl));
    }

    /**
     * Begins a transaction unless one is already active.
     *
     * @return {@code true} if this call started a new transaction and must therefore later pass
     *     {@code true} to {@link #endTransaction(boolean)}; {@code false} if an active transaction
     *     was joined
     */
    public static boolean beginTransaction() {
        EntityTransaction tx = getTransaction();
        if (!tx.isActive()) {
            tx.begin();
            return true;
        }
        return false;
    }

    /**
     * Flags the active transaction as rollback-only. Nothing is rolled back here; the actual
     * rollback is performed by {@link #endTransaction(boolean)} in the call that started the
     * transaction. Does nothing when no transaction is active or it is already flagged.
     */
    public static void rollbackTransaction() {
        EntityTransaction tx = getTransaction();
        if (tx.isActive() && !tx.getRollbackOnly()) {
            // only flag the transaction for rollback - actual rollback will happen when the method starting the transaction is done
            tx.setRollbackOnly();
        }
    }

    /**
     * Completes the transaction if the caller is the one that started it.
     *
     * <p>When {@code newTransaction} is {@code true} and the transaction is still active, it is
     * rolled back if it was flagged rollback-only, and committed otherwise. When
     * {@code newTransaction} is {@code false} the transaction is left untouched for the outer
     * caller to complete.
     *
     * @param newTransaction the value returned by the matching {@link #beginTransaction()} call
     */
    public static void endTransaction(boolean newTransaction) {
        EntityTransaction tx = getTransaction();
        if (tx.isActive() && newTransaction) {
            if (tx.getRollbackOnly()) {
                // Rollback requested deeper in the stack and this is the method starting the transaction
                tx.rollback();
            } else {
                // This is the method starting the transaction and no rollback was requested
                tx.commit();
            }
        }
    }

    /** Returns the transaction of the shared entity manager. */
    private static EntityTransaction getTransaction() {
        return SingletonEntityManager.getInstance().getTransaction();
    }

    /**
     * Invocation handler that runs each proxied call on the delegate inside a transaction and
     * rethrows whatever the delegate (or the reflective call itself) threw.
     */
    private static class TransactionalMethodWrapper implements InvocationHandler {

        private final Object delegate;

        private TransactionalMethodWrapper(Object delegate) {
            this.delegate = delegate;
        }

        /**
         * Invokes {@code method} on the delegate inside a transaction.
         *
         * <p>If the delegate throws, its exception is unwrapped from the reflective
         * {@link InvocationTargetException} and rethrown as-is; the transaction is flagged
         * rollback-only unless {@link #requiresRollback(Throwable)} says otherwise. Any other
         * throwable (a failure of the reflective call itself) always flags rollback and is rethrown
         * unchanged. The transaction is ended in {@code finally} if this call started it.
         */
        @Override
        public Object invoke(Object proxy, Method method, Object[] args) throws Throwable {
            boolean newTransaction = beginTransaction();
            try {
                return method.invoke(delegate, args);
            } catch (InvocationTargetException e) {
                // The EJB method itself threw; surface its exception, not the reflection wrapper.
                Throwable target = e.getTargetException() != null ? e.getTargetException() : e;
                if (requiresRollback(target)) {
                    rollbackTransaction();
                }
                throw target;
            } catch (Throwable t) {
                // Thrown by Method.invoke itself (e.g. IllegalAccessException, IllegalArgumentException).
                // Such throwables may have no cause, so rethrowing t.getCause() could "throw null" and
                // replace the real failure with a NullPointerException - rethrow t itself instead.
                rollbackTransaction();
                throw t;
            } finally {
                endTransaction(newTransaction);
            }
        }

        /**
         * Decides whether an exception thrown by the EJB method must flag the transaction for
         * rollback: {@code false} only if {@code thrown.getClass().getAnnotation(ApplicationException.class)}
         * is present with {@code rollback = false}; {@code true} for every other exception.
         *
         * <p>NOTE(review): under EJB defaults, un-annotated checked exceptions are application
         * exceptions that do not roll back; here they do. Kept as-is for compatibility - confirm intent.
         */
        private static boolean requiresRollback(Throwable thrown) {
            ApplicationException applicationException = thrown.getClass().getAnnotation(ApplicationException.class);
            return applicationException == null || applicationException.rollback();
        }
    }
}